# Models: the EG-VQ-VAE (EGAE) and the audio/text-based VQMotionClassifier

import torch
import torch.nn as nn
import os
import math
import pickle
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class Quantizer(nn.Module):
    """Vector quantiser with a learnable codebook of ``n_e`` vectors of width ``e_dim``.

    Every input vector is replaced by its nearest codebook entry under squared
    Euclidean distance. ``beta`` weights the commitment term of the loss.
    """

    def __init__(self, n_e, e_dim, beta):
        super(Quantizer, self).__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta

        # Small symmetric uniform init whose range shrinks with the codebook size.
        bound = 1.0 / self.n_e
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-bound, bound)

    def _nearest_code(self, z):
        """Return, for every ``e_dim``-wide row of ``z``, the index of the closest code.

        ``z`` is flattened to ``(-1, e_dim)``; the result is 1-D with one entry per row.
        """
        assert z.shape[-1] == self.e_dim
        flat = z.contiguous().view(-1, self.e_dim)
        codebook = self.embedding.weight
        # ||z||^2 + ||e||^2 - 2 z.e for every (row, code) pair -> (rows, n_e).
        dist = (flat ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(dim=1) \
            - 2 * flat.matmul(codebook.t())
        return dist.argmin(dim=1)

    def forward(self, z):
        """Quantise ``z``.

        :param z: encoder output whose last dimension is ``e_dim``
            (used as ``(B, seq_len, channel)``).
        :return: ``(loss, z_q, indices, perplexity)``. ``loss`` is the codebook
            term plus ``beta`` times the commitment term. ``z_q`` has the shape
            of ``z`` and carries straight-through gradients. ``indices`` is the
            flat 1-D tensor of chosen codes. ``perplexity`` measures codebook usage.
        """
        indices = self._nearest_code(z)
        quantized = self.embedding(indices).view(z.shape)

        codebook_term = torch.mean((quantized - z.detach()) ** 2)
        commitment_term = torch.mean((quantized.detach() - z) ** 2)
        loss = codebook_term + self.beta * commitment_term

        # Straight-through estimator: the forward value is the code, the gradient flows to z.
        quantized = z + (quantized - z).detach()

        usage = F.one_hot(indices, self.n_e).to(z.dtype).mean(dim=0)
        perplexity = torch.exp(-(usage * torch.log(usage + 1e-10)).sum())
        return loss, quantized, indices, perplexity

    def map2index(self, z):
        """Return the nearest-code indices of ``z`` reshaped to ``(z.shape[0], -1)``."""
        return self._nearest_code(z).reshape(z.shape[0], -1)

    def get_codebook_entry(self, indices):
        """Look up codebook vectors: indices of shape ``S`` -> tensor of shape ``S + (e_dim,)``."""
        vectors = self.embedding(indices.view(-1))
        return vectors.view(*indices.shape, self.e_dim).contiguous()

def init_weight(m):
    """Xavier-normal weights and zero bias for Conv1d/Linear/ConvTranspose1d modules.

    Any other module type is left untouched; intended for ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
            
class ResBlock(nn.Module):
    """Residual 1-D conv block computing ``x + Conv(LeakyReLU(Conv(x)))``.

    Both convolutions keep the channel count and (kernel 3, stride 1, padding 1)
    the sequence length, so the skip connection needs no projection.
    """

    def __init__(self, channel):
        super(ResBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, **conv_kwargs),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(channel, channel, **conv_kwargs),
        )

    def forward(self, x):
        return self.model(x) + x
    
class VQEncoder(nn.Module):
    """Temporal conv encoder from ``args.vae_test_dim`` to ``args.vae_length`` features.

    Each stage is a stride-2 ``Conv1d`` (kernel 4, padding 1), a LeakyReLU and
    a ``ResBlock``, halving the sequence length (rounding down). Stage widths
    are fixed to ``[128, args.vae_length]``, so ``args.vae_layer`` must be 2.
    Input and output are batch-first: ``(B, T, C)``.
    """

    def __init__(self, args):
        super(VQEncoder, self).__init__()
        channels = [128, args.vae_length]
        n_down = args.vae_layer
        input_size = args.vae_test_dim
        assert len(channels) == n_down
        in_widths = [input_size] + channels[:-1]
        layers = []
        for c_in, c_out in zip(in_widths, channels):
            layers.extend([
                nn.Conv1d(c_in, c_out, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(c_out),
            ])
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # Conv1d wants channels first: (B, T, C) -> (B, C, T), then back.
        features = self.main(inputs.transpose(1, 2))
        return features.transpose(1, 2)

class VQDecoder(nn.Module):
    """Temporal conv decoder from ``args.vae_length`` to ``args.vae_test_dim`` features.

    Applies two ``ResBlock``s at the latent width. It then runs
    ``args.vae_layer`` stages of nearest-neighbour x2 upsampling + ``Conv1d``
    (kernel 3) + LeakyReLU, so the sequence length doubles per stage. A final
    ``Conv1d`` at the output width comes last. Stage widths are fixed to
    ``[vae_length, 128, vae_test_dim]``, so ``args.vae_layer`` must be 2.
    Input and output are batch-first: ``(B, T, C)``.
    """

    def __init__(self, args):
        # Fixed: this called ``super(VQDecoderV3, self).__init__()``. VQDecoder
        # derives from nn.Module, not VQDecoderV3, so construction always failed.
        super(VQDecoder, self).__init__()
        channels = [args.vae_length, 128, args.vae_test_dim]
        n_up = args.vae_layer
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1
        # input_size currently always equals channels[0], so no input projection
        # is added; the branch is kept for a differing input width.
        if input_size == channels[0]:
            layers = []
        else:
            layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)]

        for _ in range(n_resblk):
            layers += [ResBlock(channels[0])]
        for i in range(n_up):
            layers += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv1d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            ]
        # Final projection at the output width, with no activation.
        layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # Conv1d wants channels first: (B, T, C) -> (B, C, T), then back.
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs

class EGAE(nn.Module):
    """Motion VQ-VAE: ``VQEncoder`` -> ``Quantizer`` -> ``VQDecoder``.

    The codebook holds ``args.vae_codebook_size`` vectors of width
    ``args.vae_length``. ``args.vae_quantizer_lambda`` is the quantiser's
    commitment weight.
    """

    def __init__(self, args):
        super(EGAE, self).__init__()
        self.vq_encoder = VQEncoder(args)
        self.quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        self.vq_decoder = VQDecoder(args)

    def gumbel_softmax(self, logits, temperature=1.0):
        """Draw a relaxed one-hot sample over the last dim of ``logits`` (Gumbel-softmax)."""
        uniform = torch.rand_like(logits)
        gumbel = -torch.log(-torch.log(uniform + 1e-10) + 1e-10)
        return F.softmax((logits + gumbel) / temperature, dim=-1)

    def forward(self, inputs):
        """Encode, quantise and reconstruct ``inputs``.

        Returns a dict containing:

        - ``poses_feat``: quantised latents.
        - ``embedding_loss`` and ``perplexity``: the quantiser's outputs.
        - ``x_recon``: the reconstruction.
        - ``probs``: Gumbel-softmax soft assignments of the pre-quantisation
          latents to the codebook entries.
        """
        latent = self.vq_encoder(inputs)
        quant_loss, quantized, _, perplexity = self.quantizer(latent)
        reconstruction = self.vq_decoder(quantized)

        # Scores are dot products with every codebook vector (not distances).
        scores = latent @ self.quantizer.embedding.weight.t()
        return {
            "poses_feat": quantized,
            "embedding_loss": quant_loss,
            "perplexity": perplexity,
            "x_recon": reconstruction,
            "probs": self.gumbel_softmax(scores),
        }

    def map2index(self, inputs):
        """Encode ``inputs`` and return their codebook indices, shaped ``(B, -1)``."""
        return self.quantizer.map2index(self.vq_encoder(inputs))

    def decode(self, index):
        """Look up the codebook vectors for ``index`` and decode them to poses."""
        return self.vq_decoder(self.quantizer.get_codebook_entry(index))

@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every rank and concatenate the copies along dim 0.

    Runs under ``torch.no_grad`` because ``torch.distributed.all_gather``
    does not propagate gradients, so the result is detached. Uses the
    default process group, which must already be initialised.
    """
    world_size = torch.distributed.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then dropout.

    Even feature columns get ``sin(pos * 10000^(-2i/d_model))`` and odd columns
    the matching ``cos``. Both even and odd ``d_model`` are supported.
    Sequences longer than ``max_len`` fail with a broadcasting error.
    """

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cos column than sin column; slicing
        # div_term keeps the shapes aligned (unsliced, odd widths raised an error).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add ``pe`` for the first ``x.shape[1]`` positions of ``x`` (batch, seq, d_model)."""
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
    
class Encoder_TRANSFORMER(nn.Module):
    """Transformer encoder over a batch-first sequence with a prepended summary slot.

    :param in_dim: input feature size, projected to ``out_dim`` by ``skelEmbedding``.
    :param out_dim: model width (d_model) of the encoder and of both outputs.
    :param period: ``"sin"`` selects ``PositionalEncoding``; any other value is
        passed as ``int(period)`` to ``PeriodicPositionalEncoding``.
    :param num_layers: number of encoder layers (4 heads, FFN width 1024, GELU).
    :param dropout: dropout for the positional encoding and the encoder layers.
    """

    def __init__(self, in_dim, out_dim, period, num_layers, dropout):
        super().__init__()
        self.latent_dim = out_dim
        self.ff_size = 1024
        self.num_layers = num_layers
        self.num_heads = 4
        self.dropout = dropout
        self.activation = "gelu"
        self.input_feats = in_dim
        self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if period == "sin":
            self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        else:
            # NOTE(review): PeriodicPositionalEncoding is not defined in the visible
            # source; it must be importable here for non-"sin" periods.
            self.sequence_pos_encoder = PeriodicPositionalEncoding(self.latent_dim, self.dropout, int(period))
        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=self.num_heads,
                                                          dim_feedforward=self.ff_size,
                                                          dropout=self.dropout,
                                                          activation=self.activation,
                                                          batch_first=True)
        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=self.num_layers)

    def _generate_square_subsequent_mask(self, sz):
        """Additive causal mask (sz, sz): 0.0 on/below the diagonal, -inf above. Unused by forward."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, x):
        """Encode ``x`` (batch, seq, in_dim) without masking.

        :return: ``(per_step, summary)``. ``per_step`` is ``(batch, seq, out_dim)``.
            ``summary`` is ``(batch, out_dim)``: the encoder output at the
            prepended zero frame.
        """
        # Prepend one all-zero frame whose encoded output serves as a sequence
        # summary. new_zeros allocates directly on x's device with x's dtype,
        # rather than a float32 CPU tensor that must be copied (and would
        # mismatch half-precision inputs).
        zero_pad = x.new_zeros(x.shape[0], 1, x.shape[2])
        x = torch.cat([zero_pad, x], dim=1)
        x = self.skelEmbedding(x)
        xseq = self.sequence_pos_encoder(x)
        final = self.seqTransEncoder(xseq, mask=None)
        return final[:, 1:, :], final[:, 0, :]

class Decoder_TRANSFORMER(nn.Module):
    """Batch-first Transformer decoder followed by a 2-layer MLP output head.

    :param in_dim: model width (d_model) of the target sequence. The memory's
        feature size must match it, because cross-attention uses the default
        key/value width.
    :param out_dim: feature size produced by the MLP head.
    :param period: ``"sin"`` selects ``PositionalEncoding``; any other value is
        passed as ``int(period)`` to ``PeriodicPositionalEncoding``.
    :param num_layers: number of decoder layers (4 heads, FFN width 1024, GELU).
    :param dropout: dropout for the positional encoding and the decoder layers.
    """
    def __init__(self, in_dim, out_dim, period, num_layers, dropout):
        super().__init__()
        self.latent_dim = in_dim
        self.ff_size = 1024
        self.num_layers = num_layers#args.num_layers
        self.num_heads = 4#args.num_heads
        self.dropout = dropout#args.dropout
        self.activation = "gelu"
        # Output head applied to every decoded step: in_dim -> in_dim -> out_dim.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.LeakyReLU(0.2, True),
            nn.Linear(in_dim, out_dim)
        )
        if period == "sin":
            self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        else:
            # NOTE(review): PeriodicPositionalEncoding is not defined in the visible
            # source; it must be importable here for non-"sin" periods.
            self.sequence_pos_encoder = PeriodicPositionalEncoding(self.latent_dim, self.dropout, int(period))
        seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                          nhead=self.num_heads,
                                                          dim_feedforward=self.ff_size,
                                                          dropout=self.dropout,
                                                          activation="gelu",
                                                          batch_first=True,
                                                         )
        self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
                                                     num_layers=self.num_layers)
    
    def _generate_square_subsequent_mask(self, sz):
        # Additive causal mask of shape (sz, sz): 0.0 where column <= row,
        # -inf above the diagonal.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def forward(self, x, merged_features, teacher_forcing, mask=None, cat=True, af_start=1):
        """Decode ``x`` against ``merged_features`` and apply the MLP head.

        :param x: target sequence, batch-first, with ``in_dim`` features per step.
        :param merged_features: batch-first decoder memory for cross-attention.
        :param teacher_forcing: if true, decode all of ``x`` in one pass under a
            causal mask; otherwise decode step by step (see below).
        :param mask: ignored. It is overwritten in the teacher-forcing branch and
            not read in the autoregressive branch.
        :param cat: unused.
        :param af_start: autoregressive mode only. Number of leading frames of
            ``x`` used as the seed; later frames of ``x`` are never read.
        :return: ``self.mlp`` applied to the decoder output (``out_dim`` features
            per step). In autoregressive mode with ``1 <= af_start <= x.shape[1]``
            it has ``x.shape[1]`` steps.
        """
        if teacher_forcing:
            xseq = self.sequence_pos_encoder(x)
            # Causal mask: step t attends only to target steps <= t.
            mask = self._generate_square_subsequent_mask(xseq.size(1)).to(x.device)
            output = self.seqTransDecoder(tgt=xseq, memory=merged_features, tgt_mask=mask)
        else: 
            # Autoregressive mode: seed with x[:, :af_start]. Each later iteration
            # appends the last decoded step (pre-MLP, in_dim wide) and re-decodes
            # the whole prefix, growing it by one step until it has x.shape[1] steps.
            # NOTE(review): no tgt_mask is passed here, so positions in the prefix
            # can attend to later ones, unlike the causal teacher-forcing pass;
            # confirm this train/inference difference is intended.
            # NOTE(review): if af_start > x.shape[1] the loop never runs and
            # `output` is unbound (UnboundLocalError).
            for i in range(af_start-1, x.shape[1]):
                if i == (af_start-1):
                    new_input = x[:, 0:af_start, :]
                else:
                    new_input = torch.cat((new_input, output[:, -1:, :]), 1)
                xseq = self.sequence_pos_encoder(new_input)    
                output = self.seqTransDecoder(xseq, merged_features) 
        output = self.mlp(output)
        return output
    
class WavEncoder(nn.Module):
    """Waveform encoder built from six ``BasicBlock``s.

    Channel widths grow 1 -> out_dim//4 -> out_dim//2 -> out_dim. ``BasicBlock``
    is not defined in the visible source; its fourth positional argument
    (5, 6, 1, 6, 1, 6 here) is presumably the stride — TODO confirm.
    ``forward`` assumes a (B, samples) waveform — TODO confirm with callers.
    It returns the block-stack output with the channel axis moved last.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.out_dim = out_dim
        quarter, half = out_dim // 4, out_dim // 2
        self.feat_extractor = nn.Sequential(
            BasicBlock(1, quarter, 15, 5, first_dilation=1600, downsample=True),
            BasicBlock(quarter, quarter, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(quarter, quarter, 15, 1, first_dilation=7),
            BasicBlock(quarter, half, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(half, half, 15, 1, first_dilation=7),
            BasicBlock(half, out_dim, 15, 6, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        # Insert a singleton channel axis at dim 1, run the stack, then move
        # the channel axis (dim 1) to the end.
        features = self.feat_extractor(wav_data.unsqueeze(1))
        return features.transpose(1, 2)

class LSTMMLP(nn.Module):
    """Batch-first bidirectional LSTM with summed directions, followed by a 2-layer MLP.

    ``forward`` returns ``(per_step, pooled)``:

    - ``per_step``: the MLP output (``out_dim`` features per step) applied to
      the sum of the forward and backward LSTM outputs.
    - ``pooled``: the final hidden state averaged over all layers and both
      directions (``hidden_size`` features).
    """

    def __init__(self, in_dim, hidden_size, out_dim, num_layers, dropout):
        super().__init__()
        self.lstm = nn.LSTM(
            in_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden_size, out_dim),
        )
        self.hidden_size = hidden_size

    def forward(self, inputs):
        seq_out, (h_n, _c_n) = self.lstm(inputs)
        # Last axis holds [forward | backward] halves of width hidden_size each.
        forward_half, backward_half = seq_out.split(self.hidden_size, dim=2)
        per_step = self.mlp(forward_half + backward_half)
        pooled = h_n.mean(dim=0)  # average pooling over layers x directions
        return per_step, pooled
    
class LP(nn.Module):
    """Single linear projection that also returns its max over dim 1.

    ``forward`` returns ``(projected, pooled)``, where ``pooled`` is the
    element-wise max over dim 1 (the time axis when inputs are batch-first).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mlp = nn.Linear(in_dim, out_dim)

    def forward(self, inputs):
        projected = self.mlp(inputs)
        pooled = projected.max(dim=1).values
        return projected, pooled

class Empty(nn.Module):
    """Parameter-free pass-through module: ``forward`` returns its input object unchanged."""

    def forward(self, inputs):
        return inputs
    
class VQMotion(nn.Module):
    """Base model holding audio, text and motion encoders plus a motion decoder.

    Every submodule is selected by ``args`` (``a_pre_encoder``/``a_encoder``,
    ``t_pre_encoder``/``t_encoder``, ``m_pre_encoder``/``m_encoder``,
    ``m_decoder``/``decode_fusion``). Each encoder is also built a second time
    as a ``*_m`` copy; nothing in this module reads those copies. No
    ``forward`` is defined here; subclasses provide it.

    :raises NotImplementedError: for an unsupported encoder choice. Also raised
        when ``m_decoder`` is ``"lstm"`` or ``"transformer"`` and
        ``decode_fusion`` contains neither ``"cat"`` nor ``"add"``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Optional projections after the pre-encoders; they stay None unless a
        # branch below creates them (subclass forward code tests them for None).
        self.a_linear = self.a_linear_m = self.t_linear = self.t_linear_m = self.m_linear = self.m_linear_m = None

        # ----------------- audio ------------------------------- #
        if args.a_pre_encoder == "camn":
            assert args.audio_rep == "wave16k"
            self.audio_pre_encoder = WavEncoder(args.audio_f)
            self.audio_pre_encoder_m = WavEncoder(args.audio_f)
        elif args.a_pre_encoder == "ha2g":
            assert "spec" in args.audio_rep
            # NOTE(review): a spectrogram audio_rep is asserted, but the same
            # WavEncoder as "camn" is built; confirm this is intended.
            self.audio_pre_encoder = WavEncoder(args.audio_f)
            self.audio_pre_encoder_m = WavEncoder(args.audio_f)
        elif args.a_pre_encoder == "wav2vec2":
            if args.a_fix_pre:
                # Presumably pre-extracted 768-d wav2vec2 features: only a
                # projection to audio_f is learned.
                self.audio_pre_encoder = nn.Linear(768, args.audio_f)
                self.audio_pre_encoder_m = nn.Linear(768, args.audio_f)
            else:
                # NOTE(review): Wav2Vec2Model is not imported in the visible source,
                # and the checkpoint path is hard-coded.
                self.audio_pre_encoder = Wav2Vec2Model.from_pretrained("/home/ma-user/work/datasets/hub/transformer/wav2vec2-base-960h")
                self.audio_pre_encoder_m = Wav2Vec2Model.from_pretrained("/home/ma-user/work/datasets/hub/transformer/wav2vec2-base-960h")
                self.a_linear = nn.Linear(768, args.audio_f)
                self.a_linear_m = nn.Linear(768, args.audio_f)
        else:
            raise NotImplementedError

        if args.a_encoder == "transformer":
            self.audio_encoder = Encoder_TRANSFORMER(args.audio_f, args.audio_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
            self.audio_encoder_m = Encoder_TRANSFORMER(args.audio_f, args.audio_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
        elif args.a_encoder == "lstm":
            self.audio_encoder = LSTMMLP(args.audio_f, args.audio_f, args.audio_f, args.n_layer, args.dropout_prob)
            self.audio_encoder_m = LSTMMLP(args.audio_f, args.audio_f, args.audio_f, args.n_layer, args.dropout_prob)
        elif args.a_encoder == "lp":
            self.audio_encoder = LP(args.audio_f, args.audio_f)
            self.audio_encoder_m = LP(args.audio_f, args.audio_f)
        else:
            raise NotImplementedError

        # ------------------------------- text ------------------------------------- #
        if args.t_pre_encoder == "bert":
            # Despite the "bert" name, this loads a static embedding table from
            # vocab.pkl; t_linear assumes 300-d word vectors.
            pre_trained_embedding = None
            # SECURITY: pickle.load can execute arbitrary code; only point
            # data_path at trusted vocab files.
            with open(f"{args.data_path}{args.train_data_path[:-6]}vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)
                pre_trained_embedding = self.lang_model.word_embedding_weights
            self.text_pre_encoder = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), freeze=args.t_fix_pre)
            self.text_pre_encoder_m = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), freeze=args.t_fix_pre)
            self.t_linear = nn.Linear(300, args.word_f)
            self.t_linear_m = nn.Linear(300, args.word_f)
        else:
            raise NotImplementedError

        if args.t_encoder == "transformer":
            self.text_encoder = Encoder_TRANSFORMER(args.word_f, args.word_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
            self.text_encoder_m = Encoder_TRANSFORMER(args.word_f, args.word_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
        elif args.t_encoder == "lstm":
            self.text_encoder = LSTMMLP(args.word_f, args.word_f, args.word_f, args.n_layer, args.dropout_prob)
            self.text_encoder_m = LSTMMLP(args.word_f, args.word_f, args.word_f, args.n_layer, args.dropout_prob)
        elif args.t_encoder == "tcn":
            # NOTE(review): TextEncoderTCN is not defined in the visible source.
            self.text_encoder = TextEncoderTCN(args, embed_size=args.word_f)
            self.text_encoder_m = TextEncoderTCN(args, embed_size=args.word_f)
        elif args.t_encoder == "lp":
            self.text_encoder = LP(args.word_f, args.word_f)
            self.text_encoder_m = LP(args.word_f, args.word_f)
        else:
            raise NotImplementedError

        # ------------------------ motion ---------------------------------- #
        # The "lstm" decoder path uses pose_dims + 1 motion input features.
        if args.m_pre_encoder == "none":
            self.motion_pre_encoder = Empty()
            self.motion_pre_encoder_m = Empty()
            if args.m_decoder == "lstm":
                motion_in_f = args.pose_dims+1
            else:
                motion_in_f = args.pose_dims
        elif args.m_pre_encoder == "lp":
            if args.m_decoder == "lstm":
                self.motion_pre_encoder = nn.Linear(args.pose_dims+1, args.motion_f)
                self.motion_pre_encoder_m = nn.Linear(args.pose_dims+1, args.motion_f)
            else:
                self.motion_pre_encoder = nn.Linear(args.pose_dims, args.motion_f)
                self.motion_pre_encoder_m = nn.Linear(args.pose_dims, args.motion_f)
            motion_in_f = args.motion_f
        else:
            raise NotImplementedError

        # NOTE(review): Encoder_TRANSFORMER_stage2, VQEncoderV3 and
        # VQEncoderV3_stage2 are not defined in the visible source.
        if args.m_encoder == "transformer":
            self.motion_encoder = Encoder_TRANSFORMER(motion_in_f, args.motion_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
            self.motion_encoder_m = Encoder_TRANSFORMER(motion_in_f, args.motion_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
        elif args.m_encoder == "lstm":
            self.motion_encoder = LSTMMLP(motion_in_f, args.motion_f, args.motion_f, args.n_layer, args.dropout_prob)
            self.motion_encoder_m = LSTMMLP(motion_in_f, args.motion_f, args.motion_f, args.n_layer, args.dropout_prob)
        elif args.m_encoder == "lp":
            self.motion_encoder = LP(motion_in_f, args.motion_f)
            self.motion_encoder_m = LP(motion_in_f, args.motion_f)
        elif args.m_encoder == "transformer_stage":
            self.motion_encoder = Encoder_TRANSFORMER_stage2(motion_in_f, args.motion_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
            self.motion_encoder_m = Encoder_TRANSFORMER_stage2(motion_in_f, args.motion_f, args.pos_encoding_type, args.n_layer, args.dropout_prob)
        elif args.m_encoder == "tcn":
            self.motion_encoder = VQEncoderV3(motion_in_f, args.motion_f)
            self.motion_encoder_m = VQEncoderV3(motion_in_f, args.motion_f)
        elif args.m_encoder == "tcn_stage":
            self.motion_encoder = VQEncoderV3_stage2(motion_in_f, args.motion_f)
            self.motion_encoder_m = VQEncoderV3_stage2(motion_in_f, args.motion_f)
        else:
            raise NotImplementedError

        # ----------------------- motion decoder --------------------- #
        # NOTE(review): any other m_decoder value leaves self.motion_decoder unset.
        if args.m_decoder in ("lstm", "transformer"):
            if "cat" in args.decode_fusion:
                decode_f = args.audio_f+args.word_f
            elif "add" in args.decode_fusion:
                assert args.audio_f == args.word_f
                decode_f = args.audio_f
            else:
                # Previously fell through and hit UnboundLocalError on decode_f.
                raise NotImplementedError(f"unsupported decode_fusion: {args.decode_fusion!r}")
            if args.m_decoder == "lstm":
                # The LSTM decoder's input width also includes the motion features.
                decode_f += motion_in_f
                self.motion_decoder = LSTMMLP(decode_f, args.hidden_size, args.pose_dims, args.n_layer, args.dropout_prob)
            else:
                self.motion_decoder = Decoder_TRANSFORMER(decode_f, args.pose_dims, args.pos_encoding_type, args.n_layer, args.dropout_prob)

        self.decode_fusion = args.decode_fusion
        self.m_decoder = args.m_decoder
        self.test_counter = 0
        self.weight_save = args.root_path+args.out_path + "custom/" + args.name + args.notes + "/"

class VQMotionClassifier(VQMotion):
    """Predict motion VQ codebook logits from audio and text.

    Reuses the audio/text encoders and motion decoder built by ``VQMotion``.
    The decoded per-frame features (``args.pose_dims`` wide) are average-pooled
    over windows of 4 frames. Each window is classified into
    ``args.vae_codebook_size`` codebook entries.
    """

    def __init__(self, args):
        # Fixed: ``super(VQMotion, self).__init__(args)`` skipped VQMotion.__init__,
        # leaving no encoders/decoder and no self.args. It also passed ``args`` on
        # to nn.Module.__init__, which rejects it.
        super().__init__(args)
        self.avg_pool = nn.AvgPool1d(4, stride=4, padding=0, ceil_mode=False, count_include_pad=True)
        self.cls_lp = nn.Linear(self.args.pose_dims, self.args.vae_codebook_size)

    @staticmethod
    def _match_length(feat, target_len):
        """Return batch-first ``feat`` with exactly ``target_len`` steps along dim 1.

        Longer sequences are truncated. Shorter ones are extended by appending
        copies of their trailing frames, repeated as often as needed.

        :raises ValueError: if ``feat`` has no steps but ``target_len`` > 0.
        """
        cur_len = feat.shape[1]
        if cur_len > target_len:
            # The former reshape with a negative length crashed in this case.
            return feat[:, :target_len, :]
        if cur_len == 0 and target_len > 0:
            raise ValueError("cannot extend an empty feature sequence")
        while cur_len < target_len:
            # When the gap exceeds cur_len this appends the whole sequence and loops.
            diff_length = target_len - cur_len
            feat = torch.cat((feat, feat[:, -diff_length:, :]), 1)
            cur_len = feat.shape[1]
        return feat

    def forward(self, seed_motion=None, in_audio=None, in_word=None, is_test="test", teacher_forcing=False):
        """Return ``{"rec_pose": logits}`` with codebook logits per 4-frame window.

        :param seed_motion: motion sequence. Its length (dim 1) sets the length
            the audio/text features are aligned to.
        :param in_audio: input to ``audio_pre_encoder``.
        :param in_word: token indices for the ``text_pre_encoder`` embedding.
        :param is_test: unused.
        :param teacher_forcing: forwarded to the transformer motion decoder.
        """
        audio_feat_seq = self.audio_pre_encoder(in_audio) if self.a_linear is None else self.a_linear(self.audio_pre_encoder(in_audio))
        audio_feat_seq_with_pad, cls_a = self.audio_encoder(audio_feat_seq)
        text_feat_seq = self.text_pre_encoder(in_word) if self.t_linear is None else self.t_linear(self.text_pre_encoder(in_word))
        text_feat_seq_with_pad, cls_t = self.text_encoder(text_feat_seq)

        # Align both modalities to the motion length before concatenating.
        target_len = seed_motion.shape[1]
        audio_feat_seq_with_pad = self._match_length(audio_feat_seq_with_pad, target_len)
        text_feat_seq_with_pad = self._match_length(text_feat_seq_with_pad, target_len)

        # Feature-wise concatenation (width audio_f + word_f).
        # NOTE(review): this is also used when decode_fusion is "add", whose
        # decoder was built for width audio_f; confirm.
        fusion_feat_seq = torch.cat((audio_feat_seq_with_pad, text_feat_seq_with_pad), 2)
        # NOTE(review): computed but never used. The "lstm" decoder was built for
        # width motion_in_f + fusion width, yet receives only fusion_feat_seq.
        pre_motion_feat_seq = self.motion_pre_encoder(seed_motion) if self.m_linear is None else self.m_linear(self.motion_pre_encoder(seed_motion))
        if self.m_decoder == "lstm":
            output, _ = self.motion_decoder(fusion_feat_seq)
        else:
            # The transformer decoder is driven from an all-zero target sequence.
            zero_pad = torch.zeros_like(fusion_feat_seq)
            output = self.motion_decoder(zero_pad, fusion_feat_seq, teacher_forcing)
        # (bs, n, pose_dims) -> average over windows of 4 frames -> (bs, n // 4, pose_dims)
        output_p = self.avg_pool(output.permute(0, 2, 1))
        output_f = self.cls_lp(output_p.permute(0, 2, 1))  # (bs, n // 4, vae_codebook_size)
        return {"rec_pose": output_f}
        